{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Introduction to TorchScript\n\n*James Reed (jamesreed@fb.com), Michael Suo (suo@fb.com)*, rev2\n\nThis tutorial is an introduction to TorchScript, an intermediate\nrepresentation of a PyTorch model (subclass of ``nn.Module``) that\ncan then be run in a high-performance environment such as C++.\n\nIn this tutorial we will cover:\n\n1. The basics of model authoring in PyTorch, including:\n\n-  Modules\n-  Defining ``forward`` functions\n-  Composing modules into a hierarchy of modules\n\n2. Specific methods for converting PyTorch modules to TorchScript, our\n   high-performance deployment runtime:\n\n-  Tracing an existing module\n-  Using scripting to directly compile a module\n-  How to compose both approaches\n-  Saving and loading TorchScript modules\n\nOnce you have completed this tutorial, we recommend continuing with\n[the follow-on tutorial](https://pytorch.org/tutorials/advanced/cpp_export.html),\nwhich walks through an example of calling a TorchScript\nmodel from C++.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import torch  # This is all you need to use both PyTorch and TorchScript!\nprint(torch.__version__)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Basics of PyTorch Model Authoring\n\nLet\u2019s start out by defining a simple ``Module``. A ``Module`` is the\nbasic unit of composition in PyTorch. It contains:\n\n1. A constructor, which prepares the module for invocation.\n2. A set of ``Parameters`` and sub-``Modules``. These are initialized\n   by the constructor and can be used by the module during invocation.\n3. A ``forward`` function. This is the code that is run when the module\n   is invoked.\n\nLet\u2019s examine a small example:\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class MyCell(torch.nn.Module):\n    def __init__(self):\n        super(MyCell, self).__init__()\n\n    def forward(self, x, h):\n        new_h = torch.tanh(x + h)\n        return new_h, new_h\n\nmy_cell = MyCell()\nx = torch.rand(3, 4)\nh = torch.rand(3, 4)\nprint(my_cell(x, h))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "So we\u2019ve:\n\n1. Created a class that subclasses ``torch.nn.Module``.\n2. Defined a constructor. The constructor doesn\u2019t do much; it just calls\n   the constructor for ``super``.\n3. Defined a ``forward`` function, which takes two inputs and returns\n   two outputs. The actual contents of the ``forward`` function are not\n   really important, but it\u2019s sort of a fake [RNN\n   cell](https://colah.github.io/posts/2015-08-Understanding-LSTMs/);\n   that is, it\u2019s a function that is applied in a loop.\n\nWe instantiated the module and made ``x`` and ``h``, which are just 3x4\nmatrices of random values. Then we invoked the cell with\n``my_cell(x, h)``. This in turn calls our ``forward`` function.\n\nLet\u2019s do something a little more interesting:\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class MyCell(torch.nn.Module):\n    def __init__(self):\n        super(MyCell, self).__init__()\n        self.linear = torch.nn.Linear(4, 4)\n\n    def forward(self, x, h):\n        new_h = torch.tanh(self.linear(x) + h)\n        return new_h, new_h\n\nmy_cell = MyCell()\nprint(my_cell)\nprint(my_cell(x, h))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We\u2019ve redefined our module ``MyCell``, but this time we\u2019ve added a\n``self.linear`` attribute, and we invoke ``self.linear`` in the forward\nfunction.\n\nWhat exactly is happening here? ``torch.nn.Linear`` is a ``Module`` from\nthe PyTorch standard library. Just like ``MyCell``, it can be invoked\nusing the call syntax. We are building a hierarchy of ``Module``s.\n\n``print`` on a ``Module`` will give a visual representation of the\n``Module``\u2019s submodule hierarchy. In our example, we can see our\n``Linear`` submodule and its parameters.\n\nBy composing ``Module``s in this way, we can succinctly and readably\nauthor models with reusable components.\n\nYou may have noticed ``grad_fn`` on the outputs. This is a detail of\nPyTorch\u2019s method of automatic differentiation, called\n[autograd](https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html).\nIn short, this system allows us to compute derivatives through\npotentially complex programs. The design allows for a massive amount of\nflexibility in model authoring.\n\nNow let\u2019s examine said flexibility:\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class MyDecisionGate(torch.nn.Module):\n    def forward(self, x):\n        if x.sum() > 0:\n            return x\n        else:\n            return -x\n\nclass MyCell(torch.nn.Module):\n    def __init__(self):\n        super(MyCell, self).__init__()\n        self.dg = MyDecisionGate()\n        self.linear = torch.nn.Linear(4, 4)\n\n    def forward(self, x, h):\n        new_h = torch.tanh(self.dg(self.linear(x)) + h)\n        return new_h, new_h\n\nmy_cell = MyCell()\nprint(my_cell)\nprint(my_cell(x, h))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We\u2019ve once again redefined our ``MyCell`` class, but this time we\u2019ve also defined\n``MyDecisionGate``. This module utilizes **control flow**. Control flow\nconsists of things like loops and ``if``-statements.\n\nMany frameworks take the approach of computing symbolic derivatives\ngiven a full program representation. However, in PyTorch, we use a\ngradient tape. We record operations as they occur, and replay them\nbackwards in computing derivatives. In this way, the framework does not\nhave to explicitly define derivatives for all constructs in the\nlanguage.\n\n![How autograd works](https://github.com/pytorch/pytorch/raw/master/docs/source/_static/img/dynamic_graph.gif)\n\nHow autograd works\n\n\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Basics of TorchScript\n\nNow let\u2019s take our running example and see how we can apply TorchScript.\n\nIn short, TorchScript provides tools to capture the definition of your\nmodel, even in light of the flexible and dynamic nature of PyTorch.\nLet\u2019s begin by examining what we call **tracing**.\n\n### Tracing ``Modules``\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class MyCell(torch.nn.Module):\n    def __init__(self):\n        super(MyCell, self).__init__()\n        self.linear = torch.nn.Linear(4, 4)\n\n    def forward(self, x, h):\n        new_h = torch.tanh(self.linear(x) + h)\n        return new_h, new_h\n\nmy_cell = MyCell()\nx, h = torch.rand(3, 4), torch.rand(3, 4)\ntraced_cell = torch.jit.trace(my_cell, (x, h))\nprint(traced_cell)\ntraced_cell(x, h)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We\u2019ve rewound a bit and taken the second version of our ``MyCell``\nclass. As before, we\u2019ve instantiated it, but this time, we\u2019ve called\n``torch.jit.trace``, passed in the ``Module``, and passed in *example\ninputs* the network might see.\n\nWhat exactly has this done? It has invoked the ``Module``, recorded the\noperations that occurred when the ``Module`` was run, and created an\ninstance of ``torch.jit.ScriptModule`` (of which ``TracedModule`` is a\nsubclass).\n\nTorchScript records its definitions in an Intermediate Representation\n(or IR), commonly referred to in deep learning as a *graph*. We can\nexamine the graph with the ``.graph`` property:\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(traced_cell.graph)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "However, this is a very low-level representation and most of the\ninformation contained in the graph is not useful for end users. Instead,\nwe can use the ``.code`` property to give a Python-syntax interpretation\nof the code:\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(traced_cell.code)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "So **why** did we do all this? There are several reasons:\n\n1. TorchScript code can be invoked in its own interpreter, which is\n   basically a restricted Python interpreter. This interpreter does not\n   acquire the Global Interpreter Lock, and so many requests can be\n   processed on the same instance simultaneously.\n2. This format allows us to save the whole model to disk and load it\n   into another environment, such as a server written in a language\n   other than Python.\n3. TorchScript gives us a representation in which we can do compiler\n   optimizations on the code to provide more efficient execution.\n4. TorchScript allows us to interface with many backend/device runtimes\n   that require a broader view of the program than individual operators.\n\nWe can see that invoking ``traced_cell`` produces the same results as\nthe Python module:\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "print(my_cell(x, h))\nprint(traced_cell(x, h))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Using Scripting to Convert Modules\n\nThere\u2019s a reason we used version two of our module, and not the one with\nthe control-flow-laden submodule. Let\u2019s examine that now:\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class MyDecisionGate(torch.nn.Module):\n    def forward(self, x):\n        if x.sum() > 0:\n            return x\n        else:\n            return -x\n\nclass MyCell(torch.nn.Module):\n    def __init__(self, dg):\n        super(MyCell, self).__init__()\n        self.dg = dg\n        self.linear = torch.nn.Linear(4, 4)\n\n    def forward(self, x, h):\n        new_h = torch.tanh(self.dg(self.linear(x)) + h)\n        return new_h, new_h\n\nmy_cell = MyCell(MyDecisionGate())\ntraced_cell = torch.jit.trace(my_cell, (x, h))\n\nprint(traced_cell.dg.code)\nprint(traced_cell.code)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Looking at the ``.code`` output, we can see that the ``if-else`` branch\nis nowhere to be found! Why? Tracing does exactly what we said it would:\nrun the code, record the operations *that happen* and construct a\nScriptModule that does exactly that. Unfortunately, things like control\nflow are erased.\n\nHow can we faithfully represent this module in TorchScript? We provide a\n**script compiler**, which does direct analysis of your Python source\ncode to transform it into TorchScript. Let\u2019s convert ``MyDecisionGate``\nusing the script compiler:\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "scripted_gate = torch.jit.script(MyDecisionGate())\n\nmy_cell = MyCell(scripted_gate)\nscripted_cell = torch.jit.script(my_cell)\n\nprint(scripted_gate.code)\nprint(scripted_cell.code)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Hooray! We\u2019ve now faithfully captured the behavior of our program in\nTorchScript. Let\u2019s now try running the program:\n\n\n"
      ]
    },
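    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# New inputs\nx, h = torch.rand(3, 4), torch.rand(3, 4)\n# Run the scripted module, which preserves the if/else branch\n# (the earlier traced_cell does not)\nscripted_cell(x, h)"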
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "### Mixing Scripting and Tracing\n\nSome situations call for using tracing rather than scripting (e.g., a\nmodule has many architectural decisions that are made based on constant\nPython values that we would prefer not to appear in TorchScript). In this\ncase, scripting can be composed with tracing: ``torch.jit.script`` will\ninline the code for a traced module, and tracing will inline the code\nfor a scripted module.\n\nAn example of the first case:\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class MyRNNLoop(torch.nn.Module):\n    def __init__(self):\n        super(MyRNNLoop, self).__init__()\n        self.cell = torch.jit.trace(MyCell(scripted_gate), (x, h))\n\n    def forward(self, xs):\n        h, y = torch.zeros(3, 4), torch.zeros(3, 4)\n        for i in range(xs.size(0)):\n            y, h = self.cell(xs[i], h)\n        return y, h\n\nrnn_loop = torch.jit.script(MyRNNLoop())\nprint(rnn_loop.code)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "And an example of the second case:\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class WrapRNN(torch.nn.Module):\n    def __init__(self):\n        super(WrapRNN, self).__init__()\n        self.loop = torch.jit.script(MyRNNLoop())\n\n    def forward(self, xs):\n        y, h = self.loop(xs)\n        return torch.relu(y)\n\ntraced = torch.jit.trace(WrapRNN(), (torch.rand(10, 3, 4)))\nprint(traced.code)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This way, scripting and tracing can each be used where the situation calls for\nit, and the two can be combined.\n\n## Saving and Loading Models\n\nWe provide APIs to save and load TorchScript modules to/from disk in an\narchive format. This format includes code, parameters, attributes, and\ndebug information, meaning that the archive is a freestanding\nrepresentation of the model that can be loaded in an entirely separate\nprocess. Let\u2019s save and load our wrapped RNN module:\n\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "traced.save('wrapped_rnn.pt')\n\nloaded = torch.jit.load('wrapped_rnn.pt')\n\nprint(loaded)\nprint(loaded.code)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As you can see, serialization preserves the module hierarchy and the\ncode we\u2019ve been examining throughout. The model can also be loaded, for\nexample, [into\nC++](https://pytorch.org/tutorials/advanced/cpp_export.html) for\nPython-free execution.\n\n### Further Reading\n\nWe\u2019ve completed our tutorial! For a more involved demonstration, check\nout the NeurIPS demo for converting machine translation models using\nTorchScript:\nhttps://colab.research.google.com/drive/1HiICg6jRkBnr5hvK2-VnMi88Vi9pUzEJ\n\n\n"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}